#include "allocate.h"
#include "error.h"
#include "internal.h"
#include "logging.h"

#include <string.h>

#ifdef LOGGING
#include <inttypes.h>
#endif

// Account for 'bytes' newly allocated bytes in the solver statistics (via
// the 'ADD' macro on 'allocated_current') and lift 'allocated_max' whenever
// the current amount reaches or exceeds it.  Does nothing without a solver,
// and compiles to a no-op if 'NMETRICS' is defined.
//
static void
inc_bytes (kissat * solver, size_t bytes)
{
#ifdef NMETRICS
  (void) solver;
  (void) bytes;
#else
  if (solver)
    {
      ADD (allocated_current, bytes);
      LOG5 ("allocated_current = %s",
            FORMAT_BYTES (solver->statistics.allocated_current));
      // Equality also refreshes (and logs) the maximum.
      if (solver->statistics.allocated_max <=
          solver->statistics.allocated_current)
        {
          solver->statistics.allocated_max =
            solver->statistics.allocated_current;
          LOG5 ("allocated_max = %s",
                FORMAT_BYTES (solver->statistics.allocated_max));
        }
    }
#endif
}

// Account for 'bytes' released bytes in the solver statistics (via the
// 'SUB' macro on 'allocated_current').  Does nothing without a solver, and
// compiles to a no-op if 'NMETRICS' is defined.
//
static void
dec_bytes (kissat * solver, size_t bytes)
{
#ifdef NMETRICS
  (void) solver;
  (void) bytes;
#else
  if (solver)
    {
      SUB (allocated_current, bytes);
      LOG5 ("allocated_current = %s",
            FORMAT_BYTES (solver->statistics.allocated_current));
    }
#endif
}

// Tracked 'malloc': requests 'bytes' bytes from the C allocator and records
// them through 'inc_bytes'.  A request for zero bytes returns a null pointer
// without calling 'malloc'.  Allocation failure is reported through
// 'kissat_inc_fatal'.
//
void *
kissat_inc_malloc (kissat * solver, size_t bytes)
{
  if (bytes == 0)
    return NULL;
  void *const block = malloc (bytes);
  LOG4 ("malloc (%zu) = %p", bytes, block);
  if (block == NULL)
    kissat_inc_fatal ("out-of-memory allocating %zu bytes", bytes);
  inc_bytes (solver, bytes);
  return block;
}

// Tracked string duplication: allocates 'strlen (str) + 1' bytes with
// 'kissat_inc_malloc' and copies 'str' including its terminating zero byte.
// The argument 'str' must be a valid zero-terminated string.
//
char *
kissat_inc_strdup (kissat * solver, const char *str)
{
  const size_t bytes = strlen (str) + 1;
  char *copy = kissat_inc_malloc (solver, bytes);
  // Copying 'bytes' bytes transfers the terminator as well.
  memcpy (copy, str, bytes);
  return copy;
}

// Tracked 'calloc' for 'n' elements of 'size' bytes each.  Returns a null
// pointer without allocating if either count is zero.  A product 'n * size'
// which would exceed 'MAX_SIZE_T', as well as allocation failure, are
// reported through 'kissat_inc_fatal'.  The allocated amount is recorded
// through 'inc_bytes'.
//
void *
kissat_inc_calloc (kissat * solver, size_t n, size_t size)
{
  if (n == 0 || size == 0)
    return NULL;
  // Division is safe here since 'size' is non-zero.
  if (n > MAX_SIZE_T / size)
    kissat_inc_fatal ("invalid 'kissat_inc_calloc (..., %zu, %zu)' call", n, size);
  void *const block = calloc (n, size);
  LOG4 ("calloc (%zu, %zu) = %p", n, size, block);
  const size_t total = n * size;
  if (block == NULL)
    kissat_inc_fatal ("out-of-memory allocating "
                  "%zu = %zu x %zu bytes", total, n, size);
  inc_bytes (solver, total);
  return block;
}

// Tracked 'free': releases 'ptr' and subtracts 'bytes' from the solver
// statistics through 'dec_bytes'.  A null 'ptr' is accepted, in which case
// 'bytes' is asserted to be zero and nothing else happens.
//
void
kissat_inc_free (kissat * solver, void *ptr, size_t bytes)
{
  if (!ptr)
    {
      assert (!bytes);
      return;
    }
  LOG4 ("free (%p[%zu])", ptr, bytes);
  dec_bytes (solver, bytes);
  free (ptr);
}

// Tracked 'realloc' from 'old_bytes' to 'new_bytes'.
//
//   - Identical sizes: 'p' is returned untouched.
//   - 'new_bytes' zero: 'p' is released with 'kissat_inc_free' and a null
//     pointer is returned.
//   - Otherwise: the old size is removed from and the new size added to the
//     statistics around the call to 'realloc'; failure is reported through
//     'kissat_inc_fatal'.
//
void *
kissat_inc_realloc (kissat * solver, void *p, size_t old_bytes, size_t new_bytes)
{
  void *res = p;
  if (old_bytes != new_bytes)
    {
      if (new_bytes)
        {
          dec_bytes (solver, old_bytes);
          res = realloc (p, new_bytes);
          LOG4 ("realloc (%p[%zu], %zu) = %p", p, old_bytes, new_bytes, res);
          // 'new_bytes' is known to be non-zero in this branch.
          if (!res)
            kissat_inc_fatal ("out-of-memory reallocating from %zu to %zu bytes",
                          old_bytes, new_bytes);
          inc_bytes (solver, new_bytes);
        }
      else
        {
          kissat_inc_free (solver, p, old_bytes);
          res = NULL;
        }
    }
  return res;
}

// Element-count flavour of 'kissat_inc_realloc': resizes 'p' from 'o' to
// 'n' elements of 'size' bytes each.  With a zero element size both 'p' and
// 'o' are asserted to be zero and a null pointer is returned.  Element
// counts whose byte size would exceed 'MAX_SIZE_T' are reported through
// 'kissat_inc_fatal'.
//
void *
kissat_inc_nrealloc (kissat * solver, void *p, size_t o, size_t n, size_t size)
{
  if (size == 0)
    {
      assert (!p);
      assert (!o);
      return NULL;
    }
  // Largest element count for which multiplying by 'size' does not overflow.
  const size_t limit = MAX_SIZE_T / size;
  if (o > limit || n > limit)
    kissat_inc_fatal ("invalid 'kissat_inc_nrealloc (..., %zu, %zu, %zu)' call",
                  o, n, size);
  const size_t old_bytes = o * size;
  const size_t new_bytes = n * size;
  return kissat_inc_realloc (solver, p, old_bytes, new_bytes);
}

// Element-count flavour of 'kissat_inc_free': releases 'ptr' which holds
// 'n' elements of 'size' bytes each.  Returns immediately, without touching
// 'ptr', if either count is zero.  A product 'n * size' which would exceed
// 'MAX_SIZE_T' is reported through 'kissat_inc_fatal'.
//
void
kissat_inc_dealloc (kissat * solver, void *ptr, size_t n, size_t size)
{
  if (n == 0 || size == 0)
    return;
  // Division is safe here since 'size' is non-zero.
  if (n > MAX_SIZE_T / size)
    kissat_inc_fatal ("invalid 'kissat_inc_dealloc (..., %zu, %zu)' call", n, size);
  kissat_inc_free (solver, ptr, n * size);
}

// Releases the zero-terminated string 'str' through 'kissat_inc_free',
// accounting for 'strlen (str) + 1' bytes.  The string is asserted to be
// non-null.
//
void
kissat_inc_delstr (kissat * solver, char *str)
{
  assert (str);
  const size_t bytes = strlen (str) + 1;
  kissat_inc_free (solver, str, bytes);
}
